function [X_all,Y_all,D,n,ps] = generate_inputs_quadrant_rule(sample,wave,tcost,covariate,SMC)

% GENERATE_INPUTS_QUADRANT_RULE  Build the inputs for quadrant-rule estimation.
%
% Input
% (1) sample: table (cross-sectional dataset) with electricity consumption
%     (pre_avg, post_avg), treatment status (opower), wave identifier
%     (opower_paper_wave), covariates, and propensity scores (ps)
% (2) wave: wave number (e.g. 3, 6, or 7) to restrict the output to that
%     wave, or 0 to use the pooled sample
% (3) tcost: program cost per household. Nonzero -> Y is a cost measure
%     (price * change in usage + tcost for treated households; >0 represents
%     cost savings); 0 -> Y is the change in kWh
% (4) covariate: 'income', 'size', 'vintage', 'min', 'max', or 'std'
%     ('min'/'max'/'std' = minimum, maximum, and standard deviation of
%     baseline consumption)
% (5) SMC: 1 -> price consumption at the social marginal cost; 0 -> at the
%     retail electricity price. Only used when tcost is nonzero.
%
% Output
% (1) X_all: [selected covariate, rounded baseline consumption (pre_avg)]
% (2) Y_all: outcome after Y_demean (demeaned and scaled cost of
%     electricity consumption, net of program cost)
% (3) D: actual Opower treatment status in the RCT
% (4) n: number of households (length of Y_all)
% (5) ps: propensity scores
%
% Errors if covariate is not one of the values listed above.

%% Select rows: all households (pooled) or a single wave
if wave==0
    rows = true(height(sample),1);
    demean_wave = 1;  % pooled sample: demean within wave
else
    rows = sample.opower_paper_wave==wave;
    demean_wave = 0;  % wave-specific analysis: demean by the "full" wave sample
end

%% Define variables (only the columns that are actually needed are read)
usage_wave   = sample.post_avg(rows,:) - sample.pre_avg(rows,:);
opower_wave  = sample.opower(rows,:);
prevuse_wave = round(sample.pre_avg(rows,:));

switch covariate
    case 'income'
        covariate_wave = sample.income(rows,:);
    case 'size'
        covariate_wave = sample.uni_size(rows,:);
    case 'vintage'
        covariate_wave = sample.vintage(rows,:);
    case 'min'
        covariate_wave = sample.min_baseline(rows,:);
    case 'max'
        covariate_wave = sample.max_baseline(rows,:);
    case 'std'
        covariate_wave = sample.std_baseline(rows,:);
    otherwise
        % Fail loudly instead of leaving X_all unassigned.
        error('generate_inputs_quadrant_rule:badCovariate', ...
            'Unknown covariate "%s"; expected income, size, vintage, min, max, or std.', ...
            char(covariate));
end

%% Y in terms of cost savings or energy conservation
if (tcost)
    % Per-kWh price: social marginal cost vs. retail price
    % (units presumably $/kWh -- TODO confirm).
    if (SMC)
        price = 0.065;
    else
        price = 0.177;
    end
    Y_all = price.*usage_wave(:,1)+tcost.*opower_wave(:,1);
else
    % tcost == 0: outcome is the raw change in consumption (kWh).
    Y_all = usage_wave(:,1);
end

% Y_demean always receives the full sample table; demean_wave picks the mode.
Y_all = Y_demean(Y_all,sample,demean_wave);

%% Define D, n, ps, and X_all as final inputs
D = opower_wave;
n = length(Y_all);
ps = sample.ps(rows,:);
X_all = [covariate_wave prevuse_wave];

end